# -*- coding: utf-8 -*- #
"""
Time                2023/3/19 12:16
Author:             mingfeng (SunnyQjm)
Email               mfeng@linux.alibaba.com
File                service_check.py
Description:
"""
import time
import requests
from abc import ABCMeta, abstractmethod, ABC
from enum import Enum, unique
from typing import Optional, Callable
from clogger import logger
from .exceptions import CmgException
from .utils import StoppableThread


@unique
class HealthState(Enum):
    """Result of a single health probe."""
    OFFLINE = 0
    ONLINE = 1


class HealthCheckBase(metaclass=ABCMeta):
    """
    Base class for periodic health checks.

    Subclasses implement :meth:`check`. After :meth:`start`, a daemon thread
    repeatedly sleeps ``interval`` seconds, runs ``check()`` and passes the
    result to :meth:`on_check`, until :meth:`stop` is called. If the target
    stays OFFLINE for more than ``deregister`` seconds after the last ONLINE
    result, ``do_unregister`` is invoked and the loop is stopped.

    Recognized ``check_params`` keys:
        interval: seconds between two probes, default 10.
        deregister: seconds of continuous failure before automatic
            deregistration, default 20; ``None`` disables it.
    """

    def __init__(self, check_params: dict,
                 on_check: Optional[Callable[[HealthState], None]] = None,
                 do_unregister: Optional[Callable[[], None]] = None):
        """
        Args:
            check_params(dict): Check configuration (see class docstring).
            on_check: Optional callback invoked with every probe result.
            do_unregister: Optional callback invoked once the target has been
                offline for longer than ``deregister`` seconds.
        """
        self.check_params = check_params
        self.interval = check_params.get("interval", 10)
        self.deregister = check_params.get("deregister", 20)
        self._inner_thread: Optional[StoppableThread] = None
        self._on_check = on_check
        self._do_unregister = do_unregister
        self._last_success_time = time.time()

    @abstractmethod
    def check(self) -> HealthState:
        """Probe the target once and return its current state."""
        pass

    def on_check(self, state: HealthState):
        """
        Handle one probe result.

        Forwards *state* to the user callback, records the time of ONLINE
        results and, for OFFLINE results, deregisters and stops the loop once
        the target has been failing for more than ``deregister`` seconds.
        """
        if self._on_check is not None:
            self._on_check(state)
        if state == HealthState.ONLINE:
            self._last_success_time = time.time()
        elif state == HealthState.OFFLINE:
            # deregister=None disables automatic deregistration; without this
            # guard the comparison below raises TypeError (float > None).
            if self.deregister is None:
                return
            if time.time() - self._last_success_time > self.deregister:
                if self._do_unregister is not None:
                    self._do_unregister()
                self.stop()

    def _do_health_check(self):
        """Thread body: sleep, probe, report; repeat until stopped."""
        # The deregister countdown starts when monitoring starts, not when
        # the object was constructed.
        self._last_success_time = time.time()
        thread = self._inner_thread
        while not thread.stopped():
            time.sleep(self.interval)
            # stop() may have been called during the sleep; do not run (and
            # report) one more probe after the monitor was stopped.
            if thread.stopped():
                break
            try:
                state = self.check()
            except Exception as exc:
                # Thread boundary: an unexpected error must not silently kill
                # the monitor; treat the target as unreachable instead.
                logger.error(f"Health check failed with {exc!r}, "
                             f"treat target as offline")
                state = HealthState.OFFLINE
            self.on_check(state)

    def start(self) -> bool:
        """
        Start the background health-check thread.

        Returns:
            bool: True if a new thread was started, False if one is already
            running.
        """
        if self._inner_thread is not None and \
                self._inner_thread.is_alive():
            return False
        self._inner_thread = StoppableThread(
            target=self._do_health_check,
            name="CMG-HEALTH-CHECK"
        )
        # Thread.setDaemon() is deprecated since Python 3.10.
        self._inner_thread.daemon = True
        self._inner_thread.start()
        return True

    def stop(self) -> bool:
        """
        Ask the background thread to stop.

        Returns:
            bool: True if a running thread was signalled, False if there was
            no running thread.
        """
        if self._inner_thread is None:
            return False
        if not self._inner_thread.is_alive():
            self._inner_thread = None
            return False
        self._inner_thread.stop()
        return True

    def join(self):
        """Wait for the background thread to finish; no-op if none exists."""
        thread = self._inner_thread
        if thread is None:
            return
        thread.join()


class HTTPHealthCheck(HealthCheckBase):
    """
    Health check that probes ``url`` with an HTTP GET request.

    The target is ONLINE only when the response status code is 200. Any
    ``requests`` exception (connection error, timeout, ...) is logged and
    reported as OFFLINE.

    Recognized ``check_params`` keys, in addition to the base ones:
        url (str): URL to probe, default "".
        timeout: request timeout in seconds, default 10.
        header (dict): request headers, default {}.
        tls_skip_verify (bool): disable TLS certificate verification,
            default False.
    """

    def __init__(self, check_params: dict,
                 on_check: Optional[Callable[[HealthState], None]] = None,
                 do_unregister: Optional[Callable[[], None]] = None):
        HealthCheckBase.__init__(self, check_params, on_check, do_unregister)
        # ``deregister`` is deliberately not re-read here: keep the value and
        # default parsed by HealthCheckBase so HTTP checks behave like the
        # other check types (a None default here broke OFFLINE handling).
        self.url = check_params.get("url", "")
        self.timeout = check_params.get("timeout", 10)
        self.header = check_params.get("header", {})
        self.tls_skip_verify = check_params.get("tls_skip_verify", False)

    def check(self) -> HealthState:
        """Issue one GET request; ONLINE iff the status code is 200."""
        try:
            res = requests.get(
                self.url, headers=self.header, verify=not self.tls_skip_verify,
                timeout=self.timeout
            )
        except requests.exceptions.RequestException as exc:
            logger.error(exc)
            return HealthState.OFFLINE
        if res.status_code == 200:
            return HealthState.ONLINE
        return HealthState.OFFLINE


class DummyHealCheck(HealthCheckBase):
    """
    Fake health check, useful for tests and demos.

    The target is reported ONLINE for ``duration`` seconds, counted from the
    first call to :meth:`check`, and OFFLINE afterwards. A negative
    ``duration`` (the default is -1) reports ONLINE forever.
    """

    def __init__(self, check_params: dict,
                 on_check: Optional[Callable[[HealthState], None]] = None,
                 do_unregister: Optional[Callable[[], None]] = None):
        HealthCheckBase.__init__(self, check_params, on_check, do_unregister)
        self.duration = check_params.get("duration", -1)
        # Timestamp of the first probe; -1 means "no probe yet".
        self.startTime: float = -1

    def check(self) -> HealthState:
        """Return ONLINE until ``duration`` seconds after the first probe."""
        if self.duration < 0:
            # Unlimited duration: never goes offline.
            return HealthState.ONLINE
        if self.startTime < 0:
            self.startTime = time.time()
        deadline = self.startTime + self.duration
        still_up = time.time() < deadline
        return HealthState.ONLINE if still_up else HealthState.OFFLINE


class ServiceCheck:
    """
    Helpers to build health-check params and to instantiate health checks.

    ``ServiceCheck.http(...)`` / ``ServiceCheck.dummy(...)`` build the
    ``check_params`` dict; :meth:`create_health_check_instance` turns such a
    dict into the matching health-check object.
    """

    @staticmethod
    def create_health_check_instance(
            check_params: dict,
            on_check: Optional[Callable[[HealthState], None]] = None,
            do_unregister: Optional[Callable[[], None]] = None
    ) -> HealthCheckBase:
        """
        Instantiate the health check described by *check_params*.

        Args:
            check_params(dict): Params built by ServiceCheck.http/dummy; the
                "type" key selects the implementation ("http" or "dummy").
            on_check: Optional callback passed to the health check.
            do_unregister: Optional deregistration callback passed to the
                health check.

        Returns:
            HealthCheckBase: The (not yet started) health check instance.

        Raises:
            CmgException: If "type" is missing/empty or not supported.
        """
        check_type = check_params.get("type")
        if not check_type:
            raise CmgException(
                "ServiceCheck: Health check params require type field, please "
                "use ServiceCheck.xxx to create health check params"
            )
        if check_type == "http":
            return HTTPHealthCheck(check_params, on_check, do_unregister)
        if check_type == "dummy":
            return DummyHealCheck(check_params, on_check, do_unregister)
        raise CmgException(
            f"ServiceCheck: Not support check type => {check_type}"
        )

    @classmethod
    def http(cls, url: str, interval: int, timeout: Optional[int] = None,
             deregister: Optional[int] = None, header: Optional[dict] = None,
             tls_skip_verify: bool = False) -> dict:
        """
        Build params for a health check that performs an HTTP GET against
        *url* every *interval* seconds.

        Args:
            url(str): URL to probe; a 200 response means the service is online.
            interval(int): Seconds to wait between two probes.
            timeout(int): Request timeout in seconds; left out of the params
                when None so the check implementation's default applies.
            deregister(int): Seconds a failing service may stay offline before
                it is automatically deregistered; left out when None.
            header(dict): Request headers; an empty dict when None.
            tls_skip_verify(bool): Skip TLS certificate verification.

        Returns:
            dict: Params accepted by create_health_check_instance.
        """
        res = {
            "type": "http",
            "url": url,
            "interval": interval,
            "header": header if header is not None else {},
            "tls_skip_verify": tls_skip_verify,
        }
        if timeout is not None:
            res["timeout"] = timeout
        if deregister is not None:
            res["deregister"] = deregister
        return res

    @classmethod
    def dummy(cls, duration: int, interval: int,
              deregister: Optional[int] = None) -> dict:
        """
        Build params for a DummyHealCheck.

        Args:
            duration(int): Seconds the fake target reports ONLINE, counted
                from the first probe; a negative value means always ONLINE.
            interval(int): Seconds to wait between two probes.
            deregister(int): Seconds offline before automatic
                deregistration; left out of the params when None.

        Returns:
            dict: Params accepted by create_health_check_instance.
        """
        res = {
            "type": "dummy",
            "interval": interval,
            "duration": duration
        }
        if deregister is not None:
            res["deregister"] = deregister
        return res
